/*
Copyright 2022 The Katalyst Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package state

import (
	"fmt"
	"math"
	"strings"
	"time"

	"k8s.io/apimachinery/pkg/util/sets"
	"k8s.io/klog/v2"

	"github.com/kubewharf/katalyst-core/pkg/agent/qrm-plugins/commonstate"
	cpuconsts "github.com/kubewharf/katalyst-core/pkg/agent/qrm-plugins/cpu/consts"
	"github.com/kubewharf/katalyst-core/pkg/agent/qrm-plugins/util/preoccupation"
	"github.com/kubewharf/katalyst-core/pkg/util/general"
	"github.com/kubewharf/katalyst-core/pkg/util/machine"
)

// GetContainerRequestedCoresFunc returns the requested cpu cores (possibly
// fractional) for the given container allocation.
type GetContainerRequestedCoresFunc func(allocationInfo *AllocationInfo) float64

var (
	// StaticPools are generated by cpu plugin statically,
	// and they will be ignored when reading cpu advisor list and watch response.
	StaticPools = sets.NewString(
		commonstate.PoolNameReserve,
	)

	// ResidentPools are guaranteed existing in state,
	// and they are usually used to ensure stability.
	// Note: package-level initialization order lets this reference ForbiddenPools below.
	ResidentPools = sets.NewString(
		commonstate.PoolNameReclaim,
	).Union(StaticPools).Union(ForbiddenPools)

	// ForbiddenPools are forbidden from being allocated to user containers and
	// are mainly used to perform specific tasks.
	ForbiddenPools = sets.NewString(
		commonstate.PoolNameInterrupt,
	)
)

// GetUnitedPoolsCPUs returns the union of the specified pools' cpus.
// A pool missing from entries is tolerated with a warning; any other lookup
// error aborts and returns the cpus accumulated so far along with the error.
func GetUnitedPoolsCPUs(poolsName sets.String, entries PodEntries) (machine.CPUSet, error) {
	unitedPoolsCPUs := machine.NewCPUSet()
	for _, poolName := range poolsName.List() {
		cpus, err := entries.GetCPUSetForPool(poolName)
		if err != nil {
			// bug fix: previously the "does not exist" warning was also emitted
			// on the success path; now it fires only for pool-not-found errors.
			if !strings.Contains(err.Error(), commonstate.PoolNotFoundErrMsg) {
				return unitedPoolsCPUs, err
			}
			general.Warningf("the current pool %s does not exist", poolName)
			continue
		}

		unitedPoolsCPUs = unitedPoolsCPUs.Union(cpus)
	}
	return unitedPoolsCPUs, nil
}

// WrapAllocationMetaFilter adapts a filter over AllocationMeta into a filter
// over AllocationInfo: the returned function extracts the info's embedded
// AllocationMeta and delegates to metaFilter. A nil AllocationInfo never matches.
func WrapAllocationMetaFilter(metaFilter func(meta *commonstate.AllocationMeta) bool) func(info *AllocationInfo) bool {
	return func(candidate *AllocationInfo) bool {
		if candidate == nil {
			// guard against nil candidates instead of panicking
			return false
		}
		return metaFilter(&candidate.AllocationMeta)
	}
}

// WrapAllocationMetaFilterWithAnnotations adapts a filter over
// (AllocationMeta, annotations) into a filter over (AllocationInfo, annotations):
// the returned function extracts the info's embedded AllocationMeta and delegates
// to metaFilter together with the candidate annotations. A nil AllocationInfo
// never matches.
func WrapAllocationMetaFilterWithAnnotations(
	metaFilter func(meta *commonstate.AllocationMeta, annotations map[string]string) bool,
) func(info *AllocationInfo, annotations map[string]string) bool {
	return func(candidate *AllocationInfo, candidateAnnotations map[string]string) bool {
		if candidate == nil {
			// guard against nil candidates instead of panicking
			return false
		}
		return metaFilter(&candidate.AllocationMeta, candidateAnnotations)
	}
}

// GetIsolatedQuantityMapFromPodEntries returns a map to indicates isolation info,
// and the map is formatted as pod -> container -> isolated-quantity
func GetIsolatedQuantityMapFromPodEntries(podEntries PodEntries, ignoreAllocationInfos []*AllocationInfo, getContainerRequestedCores GetContainerRequestedCoresFunc) map[string]map[string]int {
	isolatedQuantities := make(map[string]map[string]int)
	for podUID, containerEntries := range podEntries {
		if containerEntries.IsPoolEntry() {
			continue
		}

		for containerName, allocationInfo := range containerEntries {
			// only dedicated_cores without numa_binding are candidates for isolation
			if allocationInfo == nil || !allocationInfo.CheckDedicated() || allocationInfo.CheckDedicatedNUMABinding() {
				continue
			}

			// skip containers the caller explicitly asked to ignore
			ignored := false
			for _, ignoreInfo := range ignoreAllocationInfos {
				if allocationInfo.PodUid == ignoreInfo.PodUid && allocationInfo.ContainerName == ignoreInfo.ContainerName {
					ignored = true
					break
				}
			}
			if ignored {
				continue
			}

			// if there is no more cores to allocate, we will put dedicated_cores without numa_binding
			// to pool rather than isolation. calling this function means we will start to adjust allocation,
			// and we will try to isolate those containers, so we will treat them as containers to be isolated.
			quantity := allocationInfo.AllocationResult.Size()
			if allocationInfo.OwnerPoolName != commonstate.PoolNameDedicated {
				quantity = int(math.Ceil(getContainerRequestedCores(allocationInfo)))
			}
			if quantity == 0 {
				klog.Warningf("[GetIsolatedQuantityMapFromPodEntries] isolated pod: %s/%s container: %s get zero quantity",
					allocationInfo.PodNamespace, allocationInfo.PodName, allocationInfo.ContainerName)
				continue
			}

			if _, ok := isolatedQuantities[podUID]; !ok {
				isolatedQuantities[podUID] = make(map[string]int)
			}
			isolatedQuantities[podUID][containerName] = quantity
		}
	}
	return isolatedQuantities
}

// GetSharedQuantityMapFromPodEntries returns a map to indicates quantity info for each shared pool,
// and the map is formatted as pool -> NUMA id -> quantity.
// Containers listed in ignoreAllocationInfos (matched by pod UID + container name) are excluded.
func GetSharedQuantityMapFromPodEntries(podEntries PodEntries, ignoreAllocationInfos []*AllocationInfo, getContainerRequestedCores GetContainerRequestedCoresFunc) (map[string]map[int]int, error) {
	poolsQuantityMap := make(map[string]map[int]int)
	allocationInfosToCount := make([]*AllocationInfo, 0, len(podEntries))
	for _, entries := range podEntries {
		if entries.IsPoolEntry() {
			continue
		}

	containerLoop:
		for _, allocationInfo := range entries {
			// only count shared_cores not isolated.
			// if there is no more cores to allocate, we will put dedicated_cores without numa_binding to pool rather than isolation.
			// calling this function means we will start to adjust allocation, and we will try to isolate those containers,
			// so we will treat them as containers to be isolated.
			if allocationInfo == nil || !allocationInfo.CheckShared() || !allocationInfo.CheckMainContainer() {
				continue
			}

			for _, ignoreAllocationInfo := range ignoreAllocationInfos {
				if allocationInfo.PodUid == ignoreAllocationInfo.PodUid && allocationInfo.ContainerName == ignoreAllocationInfo.ContainerName {
					continue containerLoop
				}
			}

			allocationInfosToCount = append(allocationInfosToCount, allocationInfo)
		}
	}

	err := CountAllocationInfosToPoolsQuantityMap(allocationInfosToCount, poolsQuantityMap, getContainerRequestedCores)
	if err != nil {
		// typo fix: "faild" -> "failed"
		return nil, fmt.Errorf("CountAllocationInfosToPoolsQuantityMap failed with error: %v", err)
	}

	return poolsQuantityMap, nil
}

// GetNonBindingSharedRequestedQuantityFromPodEntries returns total quantity shared_cores without numa_binding requested
func GetNonBindingSharedRequestedQuantityFromPodEntries(podEntries PodEntries, newNonBindingSharedRequestedQuantity map[string]float64, getContainerRequestedCores GetContainerRequestedCoresFunc) int {
	total := 0.0

	for podUID, containerEntries := range podEntries {
		if containerEntries.IsPoolEntry() {
			continue
		}

		// ignore new coming pods (only for inplace update); their requests are
		// accounted from newNonBindingSharedRequestedQuantity below.
		// (lookup on a nil map is safe and always misses)
		if _, isNewPod := newNonBindingSharedRequestedQuantity[podUID]; isNewPod {
			continue
		}

		for _, allocationInfo := range containerEntries {
			if allocationInfo == nil || !allocationInfo.CheckShared() || allocationInfo.CheckNUMABinding() {
				continue
			}

			total += getContainerRequestedCores(allocationInfo)
		}
	}

	for _, quantity := range newNonBindingSharedRequestedQuantity {
		total += quantity
	}

	return CPUPreciseCeil(total)
}

// GetRequestedQuantityFromPodEntries returns the total requested cores of all
// container entries accepted by allocationFilter; pool entries are skipped.
func GetRequestedQuantityFromPodEntries(podEntries PodEntries, allocationFilter func(info *AllocationInfo) bool,
	getContainerRequestedCores GetContainerRequestedCoresFunc,
) float64 {
	total := float64(0)

	for _, containerEntries := range podEntries {
		if containerEntries.IsPoolEntry() {
			continue
		}

		for _, allocationInfo := range containerEntries {
			if allocationInfo != nil && allocationFilter(allocationInfo) {
				total += getContainerRequestedCores(allocationInfo)
			}
		}
	}

	return total
}

// GetReclaimedNUMAHeadroom sums the headroom of every NUMA node in numaSet.
// Nodes absent from numaHeadroom contribute zero.
func GetReclaimedNUMAHeadroom(numaHeadroom map[int]float64, numaSet machine.CPUSet) float64 {
	var total float64
	for _, id := range numaSet.ToSliceNoSortInt() {
		total += numaHeadroom[id]
	}
	return total
}

// GenerateMachineStateFromPodEntries for dynamic policy
func GenerateMachineStateFromPodEntries(topology *machine.CPUTopology, podEntries PodEntries, originMachineState NUMANodeMap) (NUMANodeMap, error) {
	machineState, err := GenerateMachineStateFromPodEntriesByPolicy(topology, podEntries, cpuconsts.CPUResourcePluginPolicyNameDynamic)
	if err != nil {
		return nil, err
	}

	// carry pre-occupation records over from the previous machine state
	updateMachineStatePreOccPodEntries(machineState, originMachineState)
	return machineState, nil
}

// updateMachineStatePreOccPodEntries update the pre-occupation pod from pod entries and origin machine state.
// For every NUMA node it:
//  1. collects pre-occupation candidates: containers that existed in the origin
//     machine state but are gone from the current one (filtered by
//     preoccupation.PreOccAllocationFilter), merged with any entries the origin
//     state already recorded as pre-occupied;
//  2. drops candidates whose pre-occupation has expired, and re-registers the
//     rest on the current state with a refreshed delete timestamp.
func updateMachineStatePreOccPodEntries(currentMachineState, originMachineState NUMANodeMap) {
	// override pre-occupation pod from pod entries
	now := time.Now()
	for numaID, numaState := range currentMachineState {
		preOccPodEntries := make(PodEntries)
		originalPodEntries := make(PodEntries)
		// fall back to empty maps when the NUMA node is new to this state
		// or the origin fields are unset
		if originState, ok := originMachineState[numaID]; ok {
			if originState.PodEntries != nil {
				originalPodEntries = originState.PodEntries
			}
			if originState.PreOccPodEntries != nil {
				preOccPodEntries = originState.PreOccPodEntries
			}
		}

		for podUID, containerEntries := range originalPodEntries {
			// skip pod that already in current machine state
			if _, ok := numaState.PodEntries[podUID]; ok {
				continue
			}

			for containerName, allocationInfo := range containerEntries {
				if allocationInfo == nil {
					general.Warningf("nil allocationInfo in podEntries")
					continue
				}

				// skip unneeded pre-occupation allocation info
				if !preoccupation.PreOccAllocationFilter(allocationInfo.AllocationMeta) {
					continue
				}

				if _, ok := preOccPodEntries[podUID]; !ok {
					preOccPodEntries[podUID] = make(ContainerEntries)
				}

				// a record already carried over from the origin state wins over
				// a freshly discovered one
				if _, ok := preOccPodEntries[podUID][containerName]; !ok {
					preOccPodEntries[podUID][containerName] = allocationInfo
				}
			}
		}

		for podUID, containerEntries := range preOccPodEntries {
			for containerName, preOccAllocationInfo := range containerEntries {
				if preOccAllocationInfo == nil {
					general.Warningf("nil preOccAllocationInfo in podEntries")
					continue
				}

				// expired records are deleted from the current state; live ones
				// get their delete timestamp refreshed and are written back
				if preoccupation.PreOccAllocationExpired(preOccAllocationInfo.AllocationMeta, now) {
					numaState.DeletePreOccAllocationInfo(podUID, containerName)
				} else {
					preoccupation.SetPreOccDeleteTimestamp(&preOccAllocationInfo.AllocationMeta, now)
					numaState.SetPreOccAllocationInfo(podUID, containerName, preOccAllocationInfo)
				}
			}
		}
	}
}

// GetCPUIncrRatio returns the cpu increment ratio to apply to the given
// allocation's request when sizing pools.
func GetCPUIncrRatio(allocationInfo *AllocationInfo) float64 {
	if !allocationInfo.CheckSharedNUMABinding() {
		return cpuconsts.CPUIncrRatioDefault
	}

	// multiply incrRatio for numa_binding shared_cores to allow it burst
	return cpuconsts.CPUIncrRatioSharedCoresNUMABinding
}

// GetSharedBindingNUMAsFromQuantityMap collects the real NUMA ids (i.e. not the
// faked NUMA id) that carry a positive quantity in any pool.
func GetSharedBindingNUMAsFromQuantityMap(poolsQuantityMap map[string]map[int]int) sets.Int {
	boundNUMAs := sets.NewInt()

	for _, numaQuantities := range poolsQuantityMap {
		for numaID, quantity := range numaQuantities {
			if quantity <= 0 || numaID == commonstate.FakedNUMAID {
				continue
			}
			boundNUMAs.Insert(numaID)
		}
	}

	return boundNUMAs
}

// CountAllocationInfosToPoolsQuantityMap accumulates the cpu requests of the
// given allocation infos into poolsQuantityMap (pool -> NUMA id -> quantity).
// Requests are scaled by GetCPUIncrRatio and summed with float precision first;
// the ceiling (CPUPreciseCeil) is applied only once per pool/NUMA pair so that
// per-container rounding error does not pile up.
// It returns an error when an allocation info is nil, a pool name cannot be
// resolved, or a single pool would end up spanning more than one NUMA node.
func CountAllocationInfosToPoolsQuantityMap(allocationInfos []*AllocationInfo,
	poolsQuantityMap map[string]map[int]int,
	getContainerRequestedCores GetContainerRequestedCoresFunc,
) error {
	if poolsQuantityMap == nil {
		return fmt.Errorf("nil poolsQuantityMap in CountAllocationInfosToPoolsQuantityMap")
	}

	// intermediate float sums, keyed the same way as poolsQuantityMap
	precisePoolsQuantityMap := make(map[string]map[int]float64)

	for _, allocationInfo := range allocationInfos {
		if allocationInfo == nil {
			return fmt.Errorf("CountAllocationInfosToPoolsQuantityMap got nil allocationInfo")
		}

		reqFloat64 := getContainerRequestedCores(allocationInfo) * GetCPUIncrRatio(allocationInfo)

		var targetNUMAID int
		var poolName string

		// numa_binding shared_cores must be attributed to exactly one real NUMA
		// node; everything else is counted under the faked NUMA id.
		if allocationInfo.CheckSharedNUMABinding() {
			var numaSet machine.CPUSet
			poolName = allocationInfo.GetOwnerPoolName()

			if poolName == commonstate.EmptyOwnerPoolName {
				// not admitted into a pool yet: derive pool name and NUMA set
				// from the allocation's NUMA-hint annotation
				var pErr error
				poolName, pErr = allocationInfo.GetSpecifiedNUMABindingPoolName()
				if pErr != nil {
					return fmt.Errorf("GetSpecifiedNUMABindingPoolName for %s/%s/%s failed with error: %v",
						allocationInfo.PodNamespace, allocationInfo.PodName, allocationInfo.ContainerName, pErr)
				}

				numaSet, pErr = machine.Parse(allocationInfo.Annotations[cpuconsts.CPUStateAnnotationKeyNUMAHint])
				if pErr != nil {
					return fmt.Errorf("parse numaHintStr: %s failed with error: %v",
						allocationInfo.Annotations[cpuconsts.CPUStateAnnotationKeyNUMAHint], pErr)
				}

				general.Infof(" %s/%s/%s count to specified NUMA binding pool name: %s, numaSet: %s",
					allocationInfo.PodNamespace, allocationInfo.PodName, allocationInfo.ContainerName, poolName, numaSet.String())
			} else {
				// already in a valid pool (numa aware pool or isolation pool)
				numaSet = allocationInfo.GetAllocationResultNUMASet()

				general.Infof(" %s/%s/%s count to non-empty owner pool name: %s, numaSet: %s",
					allocationInfo.PodNamespace, allocationInfo.PodName, allocationInfo.ContainerName, poolName, numaSet.String())
			}

			// numa_binding shared_cores must resolve to exactly one NUMA node
			if numaSet.Size() != 1 {
				return fmt.Errorf("numaHintStr: %s indicates invalid numaSet size for numa_binding shared_cores",
					allocationInfo.Annotations[cpuconsts.CPUStateAnnotationKeyNUMAHint])
			}

			targetNUMAID = numaSet.ToSliceNoSortInt()[0]

			if targetNUMAID < 0 {
				return fmt.Errorf("numaHintStr: %s indicates invalid numaSet numa_binding shared_cores",
					allocationInfo.Annotations[cpuconsts.CPUStateAnnotationKeyNUMAHint])
			}
		} else {
			targetNUMAID = commonstate.FakedNUMAID
			poolName = allocationInfo.GetPoolName()
		}

		if poolName == commonstate.EmptyOwnerPoolName {
			return fmt.Errorf("get poolName failed for %s/%s/%s",
				allocationInfo.PodNamespace, allocationInfo.PodName, allocationInfo.ContainerName)
		}

		// reject any attempt to spread a single pool across multiple NUMA nodes
		curLen := len(precisePoolsQuantityMap[poolName])
		if curLen > 1 {
			return fmt.Errorf("pool %s cross NUMA: %+v", poolName, precisePoolsQuantityMap[poolName])
		} else if curLen == 1 {
			for numaID := range precisePoolsQuantityMap[poolName] {
				if numaID != targetNUMAID {
					return fmt.Errorf("pool %s cross NUMA: %d, %d", poolName, numaID, targetNUMAID)
				}
			}
		} else {
			precisePoolsQuantityMap[poolName] = make(map[int]float64)
		}

		// no need to compare pools quantity in specific NUMA with NUMA all CPUs,
		// we will do it in generatePoolsAndIsolation
		precisePoolsQuantityMap[poolName][targetNUMAID] += reqFloat64
	}

	// fold the precise float sums into integer quantities, ceiling once per pool/NUMA pair
	for poolName, preciseQuantityMap := range precisePoolsQuantityMap {
		for numaID, preciseQuantity := range preciseQuantityMap {
			if poolsQuantityMap[poolName] == nil {
				poolsQuantityMap[poolName] = make(map[int]int)
			}

			poolsQuantityMap[poolName][numaID] += CPUPreciseCeil(preciseQuantity)

			// return err will abort the procedure,
			// so there is no need to revert modifications made in parameter poolsQuantityMap
			if len(poolsQuantityMap[poolName]) > 1 {
				return fmt.Errorf("pool %s cross NUMA: %+v", poolName, poolsQuantityMap[poolName])
			}
		}
	}

	return nil
}

// GetSharedNUMABindingTargetNuma resolves the single NUMA node a numa_binding
// shared_cores allocation is bound to. When the allocation has no owner pool
// yet, the pool name and NUMA set are derived from the allocation's NUMA-hint
// annotation; otherwise the NUMA set comes from the allocation result.
// It returns commonstate.FakedNUMAID together with an error when the resolved
// NUMA set does not contain exactly one non-negative NUMA id.
func GetSharedNUMABindingTargetNuma(allocationInfo *AllocationInfo) (int, error) {
	var numaSet machine.CPUSet
	poolName := allocationInfo.GetOwnerPoolName()

	if poolName == commonstate.EmptyOwnerPoolName {
		// not admitted into a pool yet: fall back to the NUMA-hint annotation
		var pErr error
		poolName, pErr = allocationInfo.GetSpecifiedNUMABindingPoolName()
		if pErr != nil {
			return commonstate.FakedNUMAID, fmt.Errorf("GetSpecifiedNUMABindingPoolName for %s/%s/%s failed with error: %v",
				allocationInfo.PodNamespace, allocationInfo.PodName, allocationInfo.ContainerName, pErr)
		}

		numaSet, pErr = machine.Parse(allocationInfo.Annotations[cpuconsts.CPUStateAnnotationKeyNUMAHint])
		if pErr != nil {
			return commonstate.FakedNUMAID, fmt.Errorf("parse numaHintStr: %s failed with error: %v",
				allocationInfo.Annotations[cpuconsts.CPUStateAnnotationKeyNUMAHint], pErr)
		}

		general.Infof(" %s/%s/%s count to specified NUMA binding pool name: %s, numaSet: %s",
			allocationInfo.PodNamespace, allocationInfo.PodName, allocationInfo.ContainerName, poolName, numaSet.String())
	} else {
		// already in a valid pool (numa aware pool or isolation pool)
		numaSet = allocationInfo.GetAllocationResultNUMASet()

		general.Infof(" %s/%s/%s count to non-empty owner pool name: %s, numaSet: %s",
			allocationInfo.PodNamespace, allocationInfo.PodName, allocationInfo.ContainerName, poolName, numaSet.String())
	}

	// a numa_binding shared_cores allocation must map to exactly one NUMA node
	if numaSet.Size() != 1 {
		return commonstate.FakedNUMAID, fmt.Errorf("numaHintStr: %s indicates invalid numaSet size for numa_binding shared_cores",
			allocationInfo.Annotations[cpuconsts.CPUStateAnnotationKeyNUMAHint])
	}

	targetNUMAID := numaSet.ToSliceNoSortInt()[0]

	if targetNUMAID < 0 {
		return commonstate.FakedNUMAID, fmt.Errorf("numaHintStr: %s indicates invalid numaSet numa_binding shared_cores",
			allocationInfo.Annotations[cpuconsts.CPUStateAnnotationKeyNUMAHint])
	}

	return targetNUMAID, nil
}

// CheckAllocationInfoTopologyAwareAssignments reports whether the two allocation
// infos carry identical topology-aware assignments.
func CheckAllocationInfoTopologyAwareAssignments(ai1, ai2 *AllocationInfo) bool {
	first, second := ai1.TopologyAwareAssignments, ai2.TopologyAwareAssignments
	return checkCPUSetMap(first, second)
}

// CheckAllocationInfoOriginTopologyAwareAssignments reports whether the two
// allocation infos carry identical original topology-aware assignments.
func CheckAllocationInfoOriginTopologyAwareAssignments(ai1, ai2 *AllocationInfo) bool {
	first, second := ai1.OriginalTopologyAwareAssignments, ai2.OriginalTopologyAwareAssignments
	return checkCPUSetMap(first, second)
}

// checkCPUSetMap reports whether the two NUMA-id -> CPUSet maps are equal:
// same length, and every entry of map1 matched by an equal entry in map2.
// (a key missing from map2 compares against the zero-value CPUSet)
func checkCPUSetMap(map1, map2 map[int]machine.CPUSet) bool {
	if len(map1) != len(map2) {
		return false
	}
	for node, cpus := range map1 {
		if !map2[node].Equals(cpus) {
			return false
		}
	}
	return true
}

// CPUPreciseCeil we can not use math.Ceil directly here, because the cpu requests are stored using floats,
// there is a chance of precision issues during addition calculations.
// in critical case:
// - the allocatable cpu of the node is 122
// - the sum of allocated cpu requests is 118.00000000000001 (after ceil is 119),
// - the new pod request is 4
// 119 + 4 > 122, so qrm will reject the new pod.
//
// The request is first snapped to milli-CPU granularity with math.Round (not
// truncation): truncating turns a genuine milli-CPU request such as 118.001 —
// which scales to 118000.999... in floating point — into 118000 millis and
// under-ceils it to 118 instead of 119. Rounding handles both directions of
// floating-point drift while still mapping 118.00000000000001 to 118.
func CPUPreciseCeil(request float64) int {
	return int(math.Ceil(math.Round(request*1000) / 1000))
}

// GenerateMachineStateFromPodEntriesByPolicy returns NUMANodeMap for given resource based on
// machine info and reserved resources along with existed pod entries and policy name.
// For every NUMA node it slices each container's allocation down to the cpus of
// that node, records the per-node allocation infos, and derives the node's
// allocated/default cpusets according to the policy:
//   - dynamic: only dedicated_cores with numa_binding count as allocated;
//   - native: only entries in the dedicated pool count as allocated.
//
// todo: extracting entire state package as a common standalone utility
func GenerateMachineStateFromPodEntriesByPolicy(topology *machine.CPUTopology, podEntries PodEntries, policyName string) (NUMANodeMap, error) {
	if topology == nil {
		return nil, fmt.Errorf("GenerateMachineStateFromPodEntriesByPolicy got nil topology")
	}

	machineState := make(NUMANodeMap)
	for _, numaNode := range topology.CPUDetails.NUMANodes().ToSliceInt64() {
		numaNodeState := &NUMANodeState{}
		numaNodeAllCPUs := topology.CPUDetails.CPUsInNUMANodes(int(numaNode))
		allocatedCPUsInNumaNode := machine.NewCPUSet()

		for podUID, containerEntries := range podEntries {
			if containerEntries.IsPoolEntry() {
				continue
			}
			for containerName, allocationInfo := range containerEntries {
				if allocationInfo == nil {
					general.Warningf("nil allocationInfo in podEntries")
					continue
				}

				// the container hasn't cpuset assignment in the current NUMA node
				if allocationInfo.OriginalTopologyAwareAssignments[int(numaNode)].Size() == 0 &&
					allocationInfo.TopologyAwareAssignments[int(numaNode)].Size() == 0 {
					continue
				}

				switch policyName {
				case cpuconsts.CPUResourcePluginPolicyNameDynamic:
					// only modify allocated and default properties in NUMA node state if the policy is dynamic and the entry indicates numa_binding.
					// shared_cores with numa_binding also contributes to numaNodeState.AllocatedCPUSet,
					// it's convenient that we can skip NUMA with AllocatedCPUSet > 0 when allocating CPUs for dedicated_cores with numa_binding.
					if allocationInfo.CheckDedicatedNUMABinding() {
						allocatedCPUsInNumaNode = allocatedCPUsInNumaNode.Union(allocationInfo.OriginalTopologyAwareAssignments[int(numaNode)])
					}
				case cpuconsts.CPUResourcePluginPolicyNameNative:
					// only modify allocated and default properties in NUMA node state if the policy is native and the QoS class is Guaranteed
					if allocationInfo.CheckDedicatedPool() {
						allocatedCPUsInNumaNode = allocatedCPUsInNumaNode.Union(allocationInfo.OriginalTopologyAwareAssignments[int(numaNode)])
					}
				}

				// restrict the recorded allocation to the cpus of this NUMA node
				allocationResult := allocationInfo.AllocationResult.Intersection(numaNodeAllCPUs)
				originalAllocationResult := allocationInfo.OriginalAllocationResult.Intersection(numaNodeAllCPUs)

				topologyAwareAssignments := map[int]machine.CPUSet{int(numaNode): allocationResult}
				originalTopologyAwareAssignments := map[int]machine.CPUSet{int(numaNode): originalAllocationResult}

				// clone so the per-node view does not alias the global allocation info
				numaNodeAllocationInfo := allocationInfo.Clone()
				numaNodeAllocationInfo.AllocationResult = allocationResult
				numaNodeAllocationInfo.OriginalAllocationResult = originalAllocationResult
				numaNodeAllocationInfo.TopologyAwareAssignments = topologyAwareAssignments
				numaNodeAllocationInfo.OriginalTopologyAwareAssignments = originalTopologyAwareAssignments

				numaNodeState.SetAllocationInfo(podUID, containerName, numaNodeAllocationInfo)
			}
		}

		// default cpus are whatever the policy did not mark as allocated
		numaNodeState.AllocatedCPUSet = allocatedCPUsInNumaNode
		numaNodeState.DefaultCPUSet = numaNodeAllCPUs.Difference(numaNodeState.AllocatedCPUSet)
		machineState[int(numaNode)] = numaNodeState
	}
	return machineState, nil
}
